#!/usr/bin/env python3
# coding: utf-8

import argparse
import logging
import logging.handlers
import os
import sys
import time

import django
from django.core import management
from django.db.utils import OperationalError

# Absolute directory that contains this script, and its 'apps' sub-directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
APP_DIR = os.path.join(BASE_DIR, 'apps')

# Work from inside APP_DIR and put it first on sys.path, so the packages
# imported further down (`jumpserver`, `common`) are resolved from there.
os.chdir(APP_DIR)
sys.path.insert(0, APP_DIR)
# setdefault: a DJANGO_SETTINGS_MODULE already present in the environment wins.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
# Initialise Django before any management command below is invoked.
django.setup()

# Configure the root logger at DEBUG level with a timestamped format.
# basicConfig() is a no-op if the root logger already has handlers at this point.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

# Resolve the application version from `jumpserver.const`. If the package
# cannot be imported the script cannot do anything useful, so report the
# problem and exit with status 1.
try:
    from jumpserver import const

    __version__ = const.VERSION
except ImportError as e:
    print("Not found __version__: {}".format(e))
    # Emit the label and the interpreter path together on one stream; they
    # were previously split between stdout (print) and the logging handler.
    print("Python is: {}".format(sys.executable))
    sys.exit(1)

# Project-local imports used by the rest of this script. Any ImportError is
# reported on stdout and terminates the process with exit status 1.
try:
    # NOTE(review): CONFIG is not referenced in the visible part of this file.
    from jumpserver.const import CONFIG
    from common.utils.file import download_file
    from common.utils import make_dirs
except ImportError as e:
    print("Import error: {}".format(e))
    # NOTE(review): this hint is printed for every ImportError, not only for a
    # missing config file; presumably a missing config.yml is the usual cause.
    print("Could not find config file, `cp config_example.yml config.yml`")
    sys.exit(1)

# NOTE(review): the interpreter reads PYTHONIOENCODING at startup, so setting it
# here cannot change this process's own stdio encoding; presumably it is meant
# for child processes that inherit the environment. Confirm.
os.environ["PYTHONIOENCODING"] = "UTF-8"

# NOTE(review): logging.basicConfig() does nothing when the root logger already
# has handlers. The earlier basicConfig() call in this file normally installs
# one, so this INFO-level configuration most likely never takes effect.
# Confirm which level is intended.
logging.basicConfig(
    format='%(asctime)s %(message)s', level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)

# getLogger() without a name returns the root logger.
logger = logging.getLogger()

# Best-effort creation of the data directories used for static and media
# files. Each directory is attempted independently, so a failure on one does
# not skip the other. Failures are logged instead of silently swallowed, and
# KeyboardInterrupt / SystemExit are no longer intercepted.
for _data_subdir in ("static", "media"):
    try:
        make_dirs(os.path.join(BASE_DIR, "data", _data_subdir))
    except Exception as exc:
        logging.warning("Make data dir {} failed: {}".format(_data_subdir, exc))


def check_database_connection():
    """Wait until Django's ``check`` command succeeds for the default database.

    Up to 60 attempts are made, sleeping one second after each failed one.
    Returns ``None`` as soon as an attempt succeeds; if every attempt fails,
    the process exits with status 10.
    """
    attempt = 0
    while attempt < 60:
        logging.info(f"Check database connection: {attempt}")
        try:
            management.call_command('check', '--database', 'default')
        except OperationalError:
            logging.info('Database not setup, retry')
        except Exception as err:
            # Any other failure is reported but retried the same way.
            logging.error('Unexpect error occur: {}'.format(str(err)))
        else:
            logging.info("Database connect success")
            return
        time.sleep(1)
        attempt += 1

    logging.error("Connection database failed, exit")
    sys.exit(10)


def expire_caches():
    """Run the ``expire_caches`` management command on a best-effort basis.

    A failure is logged and otherwise ignored so that start-up can continue.
    Only ``Exception`` subclasses are caught, so ``KeyboardInterrupt`` and
    ``SystemExit`` propagate instead of being swallowed.
    """
    try:
        management.call_command('expire_caches')
    except Exception as exc:
        logging.error("Expire caches err: {}".format(exc))


def perform_db_migrate():
    """Run Django's ``migrate`` command.

    On any exception the traceback is logged and the process exits with
    status 11.
    """
    for notice in (
        "Check database structure change ...",
        "Migrate model change to database ...",
    ):
        logging.info(notice)

    try:
        management.call_command('migrate')
    except Exception:
        # logging.exception() is logging.error() with exc_info=True.
        logging.exception('Perform migrate failed, exit')
        sys.exit(11)


def collect_static():
    """Run Django's ``collectstatic`` command non-interactively with ``-c``.

    This stays best-effort: a failure is logged and the function returns
    normally, so callers carry on. Only ``Exception`` subclasses are caught,
    so ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed.
    """
    logging.info("Collect static files")
    try:
        management.call_command('collectstatic', '--no-input', '-c', verbosity=0, interactive=False)
    except Exception as exc:
        logging.error("Collect static files err: {}".format(exc))
    else:
        logging.info("Collect static files done")


def compile_i18n_file():
    """Compile the translation catalogs unless the ``zh`` one already exists.

    Returns immediately when ``locale/zh/LC_MESSAGES/django.mo`` is present
    under APP_DIR. Otherwise changes the working directory to APP_DIR and
    runs Django's ``compilemessages`` command.
    """
    django_mo_file = os.path.join(APP_DIR, 'locale', 'zh', 'LC_MESSAGES', 'django.mo')
    if os.path.exists(django_mo_file):
        return
    os.chdir(APP_DIR)
    # The stock `compilemessages` command defines no `interactive` option, and
    # call_command() raises TypeError for options a command does not define,
    # so only `verbosity` is passed.
    # NOTE(review): assumes the project does not override `compilemessages`
    # with a variant that accepts `interactive`.
    management.call_command('compilemessages', verbosity=0)
    logging.info("Compile i18n files done")


def download_ip_db():
    """Download the IP database files that are missing or too small.

    Each file lives under ``APP_DIR/common/utils/ip``. A file is fetched with
    ``download_file`` unless it already exists as a regular file larger than
    1000 bytes.
    """
    ip_dir = os.path.join(APP_DIR, 'common', 'utils', 'ip')
    # (path components relative to ip_dir, download URL)
    sources = (
        (
            ('geoip', 'GeoLite2-City.mmdb'),
            'https://jms-pkg.oss-cn-beijing.aliyuncs.com/ip/GeoLite2-City.mmdb',
        ),
        (
            ('ipip', 'ipipfree.ipdb'),
            'https://jms-pkg.oss-cn-beijing.aliyuncs.com/ip/ipipfree.ipdb',
        ),
    )
    for rel_parts, url in sources:
        target = os.path.join(ip_dir, *rel_parts)
        already_present = os.path.isfile(target) and os.path.getsize(target) > 1000
        if not already_present:
            print("Download ip db: {}".format(target))
            download_file(url, target)

def install_builtin_applets():
    """Run the ``install_builtin_applets`` management command.

    Any exception raised by the command is logged as an error and not
    re-raised.
    """
    logging.info("Install builtin applets")
    try:
        management.call_command('install_builtin_applets', verbosity=0)
    except Exception as err:
        message = "Install builtin applets err: {}".format(err)
        logging.error(message)


def upgrade_db():
    """Collect static files, then apply database migrations."""
    for step in (collect_static, perform_db_migrate):
        step()


def prepare():
    """Run the pre-start steps in their fixed order.

    Order: database connectivity check, static collection and migrations,
    cache expiry, IP database download, builtin applet installation.
    """
    steps = (
        check_database_connection,
        upgrade_db,
        expire_caches,
        download_ip_db,
        install_builtin_applets,
    )
    for step in steps:
        step()


def start_services():
    """Forward the parsed command line to the matching management command.

    Reads the module-level ``args`` and ``action`` set by the ``__main__``
    block. When the action is ``start`` and the services include ``all`` or
    ``web``, ``prepare()`` runs first. ``KeyboardInterrupt`` and any other
    exception from the command are logged, followed by a two-second sleep;
    neither is re-raised.
    """
    services = args.services
    if not isinstance(services, list):
        # A single value (e.g. the string default) is wrapped into a list.
        services = [services]

    if action == 'start' and {'all', 'web'}.intersection(services):
        prepare()

    start_args = []
    start_args += ['--daemon'] if args.daemon else []
    start_args += ['--worker', str(args.worker)] if args.worker else []
    start_args += ['--force'] if args.force else []

    command_args = [*services, *start_args]
    try:
        management.call_command(action, *command_args)
    except KeyboardInterrupt:
        logging.info('Cancel ...')
        time.sleep(2)
    except Exception as err:
        logging.error("Start service error {}: {}".format(services, err))
        time.sleep(2)


if __name__ == '__main__':
    # Command line: one positional action, zero or more service names, and
    # the optional daemon / worker / force flags.
    parser = argparse.ArgumentParser(
        description="""
        Jumpserver service control tools;

        Example: \r\n

        %(prog)s start all -d;
        """
    )
    parser.add_argument(
        'action',
        type=str,
        choices=("start", "stop", "restart", "status", "upgrade_db", "collect_static"),
        help="Action to run",
    )
    parser.add_argument(
        "services",
        type=str,
        default='all',
        nargs="*",
        choices=("all", "web", "task"),
        help="The service to start",
    )
    parser.add_argument('-d', '--daemon', nargs="?", const=True)
    parser.add_argument('-w', '--worker', type=int, nargs="?")
    parser.add_argument('-f', '--force', nargs="?", const=True)

    # `args` and `action` stay module-level names: start_services() reads them.
    args = parser.parse_args()
    action = args.action

    # These two actions run directly; every other action is a service command.
    standalone_actions = {
        "upgrade_db": upgrade_db,
        "collect_static": collect_static,
    }
    handler = standalone_actions.get(action, start_services)
    handler()
